import pandas as pd
from sklearn.model_selection import train_test_split
from collections import Counter
from utils.config import Config
root_path = 'D:/HeiMa/Pycharm/group4_nlp_project'
conf = Config(root_path)


# 读取数据
data = pd.read_csv(conf.clean_file)

# 提取所有不重复的cat类别，并保存为 class.csv
cat_classes = data["cat"].unique()
pd.DataFrame(cat_classes).to_csv("data/class.csv", index=False, header=False)

# 统计类别和标签的分布
cat_class_num = Counter(data["cat"])
label_class_num = Counter(data["label"])

print("类别分布：", cat_class_num)
print("标签分布：", label_class_num)

# 按照 80% 训练集, 20% 测试集 划分，分层抽样（保证 cat 和 label 的分布一致）
stratify_cols = data[["cat", "label"]]
train_data, test_data = train_test_split(
    data,
    test_size=0.2,
    random_state=42
)

# 保存为 CSV 文件
train_data.to_csv(conf.train_path, index=False)
test_data.to_csv(conf.test_path, index=False)

print("数据已成功划分为 train.csv (80%) 和 test.csv (20%)")